# Bayesian network analysis of wt/PiZ mutant mice liver expression (RNA-seq) data using bnlearn
# input: n x (m+1) matrix of discretized gene expression values and treatment status (wt vs piz) for zinc transporter genes
# output: network structure and fitted Bayesian model

# folder holding the bnlearn inputs; keep the trailing slash, the input path
# is assembled with paste0()
dir <- "./analysis/bnlearn/"
# dataset accession, used as the prefix of the input file name
gse <- "GSE141593"

# libraries and functions -------------------------------------------------
library(bnlearn)
library(Rgraphviz)


# load data and convert every column to a factor ----------------------------
# the .RData file is expected to provide `pizmice.counts`: discretized
# expression values per gene plus the treatment status, one row per sample
load(paste0(dir, gse, "_zn_transporter_gene_expr.RData"))
stopifnot(exists("pizmice.counts"))

# as.data.frame() accepts a matrix as well as a data frame and keeps the
# sample ids (row names) and the column names as they are
pizmice.counts <- as.data.frame(pizmice.counts)

# assigning into `[]` swaps the columns in place: row names survive and the
# column names are not rewritten the way data.frame(check.names = TRUE) would
pizmice.counts[] <- lapply(pizmice.counts, factor)
str(pizmice.counts)

# generate structure from data --------------------------------------------

# # constraint-based algorithms
# pzm.gs <- gs(pizmice.counts) # grow-shrink
# pzm.gs
# pzm.iamb <- iamb(pizmice.counts)
# pzm.iamb

# score-based algorithm: hill-climbing with 500 random restarts
# the restarts draw from R's random number generator, so the seed is fixed to
# get the same network on every run
set.seed(42)
pzm.hc <- hc(pizmice.counts, restart = 500)
pzm.hc

# # compare network scores;
# # the constraint-based structures are only partially directed, so they are
# # extended to a DAG with cextend() before scoring
# score(cextend(pzm.gs), data = pizmice.counts, type = "bic")
# score(cextend(pzm.iamb), data = pizmice.counts, type = "bic")
# score(pzm.hc, data = pizmice.counts, type = "bic")

# all.equal(cpdag(pzm.hc), cpdag(cextend(pzm.iamb)))


# plot network ------------------------------------------------------------
# the gene list has one symbol per line; stringsAsFactors = FALSE keeps the
# symbols as character on R < 4.0 as well
zn.tp.genes <- read.table("./data/zn_transporter_genes.txt", sep = "\n",
                          header = FALSE, stringsAsFactors = FALSE)[["V1"]]
# keep only the listed genes that are nodes of the learned network
zn.tp.genes <- zn.tp.genes[zn.tp.genes %in% nodes(pzm.hc)]

# highlight the listed genes in orange; when none of them is in the network
# fall back to highlight = NULL rather than passing an empty node set
graphviz.plot(
  pzm.hc,
  layout = "dot",
  highlight = if (length(zn.tp.genes) > 0) {
    list(nodes = zn.tp.genes, fill = "orange")
  } else {
    NULL
  },
  fontsize = 20
)
# graphviz.plot(pzm.iamb)

# estimate parameters -----------------------------------------------------
# using hill-climbing results for DAG

# drop the constraint-based networks if they exist; filtering the names by
# ls() keeps rm() from warning while the section that builds them is
# commented out
rm(list = intersect(c("pzm.gs", "pzm.iamb"), ls()))

# maximum likelihood fit of the parameters on the hill-climbing structure
pz.mle.fit <- bn.fit(pzm.hc, data = pizmice.counts, method = "mle")
pz.mle.fit
# parameters fitted for the pizMutation node
# (presumably the wt/PiZ treatment status column -- confirm against the input)
pz.mle.fit$pizMutation


# same structure, parameters estimated with the "bayes" method instead
pz.bayes.fit <- bn.fit(pzm.hc, data = pizmice.counts, method = "bayes")
pz.bayes.fit
pz.bayes.fit$pizMutation
